Прогнозирование заказов такси¶
Аггрегатор такси собрал исторические данные о заказах такси в аэропортах. Чтобы привлекать больше водителей в период пиковой нагрузки, нужно спрогнозировать количество заказов такси на следующий час. Необходимо построить модель для такого предсказания.
Значение метрики RMSE на тестовой выборке должно быть не больше 48.
Данные лежат в файле taxi.csv. Количество заказов находится в столбце num_orders.
Подготовка¶
# Установка библиотек
!pip install phik
!pip install --upgrade scikit-learn
!pip install catboost
Requirement already satisfied: phik in /Users/NOrlov/myenv/lib/python3.12/site-packages (0.12.4) Requirement already satisfied: numpy>=1.18.0 in /Users/NOrlov/myenv/lib/python3.12/site-packages (from phik) (1.26.4) Requirement already satisfied: scipy>=1.5.2 in /Users/NOrlov/myenv/lib/python3.12/site-packages (from phik) (1.13.0) Requirement already satisfied: pandas>=0.25.1 in /Users/NOrlov/myenv/lib/python3.12/site-packages (from phik) (2.2.1) Requirement already satisfied: matplotlib>=2.2.3 in /Users/NOrlov/myenv/lib/python3.12/site-packages (from phik) (3.8.4) Requirement already satisfied: joblib>=0.14.1 in /Users/NOrlov/myenv/lib/python3.12/site-packages (from phik) (1.3.2) Requirement already satisfied: contourpy>=1.0.1 in /Users/NOrlov/myenv/lib/python3.12/site-packages (from matplotlib>=2.2.3->phik) (1.2.1) Requirement already satisfied: cycler>=0.10 in /Users/NOrlov/myenv/lib/python3.12/site-packages (from matplotlib>=2.2.3->phik) (0.12.1) Requirement already satisfied: fonttools>=4.22.0 in /Users/NOrlov/myenv/lib/python3.12/site-packages (from matplotlib>=2.2.3->phik) (4.50.0) Requirement already satisfied: kiwisolver>=1.3.1 in /Users/NOrlov/myenv/lib/python3.12/site-packages (from matplotlib>=2.2.3->phik) (1.4.5) Requirement already satisfied: packaging>=20.0 in /Users/NOrlov/myenv/lib/python3.12/site-packages (from matplotlib>=2.2.3->phik) (24.0) Requirement already satisfied: pillow>=8 in /Users/NOrlov/myenv/lib/python3.12/site-packages (from matplotlib>=2.2.3->phik) (10.3.0) Requirement already satisfied: pyparsing>=2.3.1 in /Users/NOrlov/myenv/lib/python3.12/site-packages (from matplotlib>=2.2.3->phik) (3.1.2) Requirement already satisfied: python-dateutil>=2.7 in /Users/NOrlov/myenv/lib/python3.12/site-packages (from matplotlib>=2.2.3->phik) (2.9.0.post0) Requirement already satisfied: pytz>=2020.1 in /Users/NOrlov/myenv/lib/python3.12/site-packages (from pandas>=0.25.1->phik) (2024.1) Requirement already satisfied: tzdata>=2022.7 in /Users/NOrlov/myenv/lib/python3.12/site-packages (from pandas>=0.25.1->phik) (2024.1) Requirement already satisfied: six>=1.5 in /Users/NOrlov/myenv/lib/python3.12/site-packages (from python-dateutil>=2.7->matplotlib>=2.2.3->phik) (1.16.0) Requirement already satisfied: scikit-learn in /Users/NOrlov/myenv/lib/python3.12/site-packages (1.4.2) Requirement already satisfied: numpy>=1.19.5 in /Users/NOrlov/myenv/lib/python3.12/site-packages (from scikit-learn) (1.26.4) Requirement already satisfied: scipy>=1.6.0 in /Users/NOrlov/myenv/lib/python3.12/site-packages (from scikit-learn) (1.13.0) Requirement already satisfied: joblib>=1.2.0 in /Users/NOrlov/myenv/lib/python3.12/site-packages (from scikit-learn) (1.3.2) Requirement already satisfied: threadpoolctl>=2.0.0 in /Users/NOrlov/myenv/lib/python3.12/site-packages (from scikit-learn) (3.4.0) Requirement already satisfied: catboost in /Users/NOrlov/myenv/lib/python3.12/site-packages (1.2.3) Requirement already satisfied: graphviz in /Users/NOrlov/myenv/lib/python3.12/site-packages (from catboost) (0.20.3) Requirement already satisfied: matplotlib in /Users/NOrlov/myenv/lib/python3.12/site-packages (from catboost) (3.8.4) Requirement already satisfied: numpy>=1.16.0 in /Users/NOrlov/myenv/lib/python3.12/site-packages (from catboost) (1.26.4) Requirement already satisfied: pandas>=0.24 in /Users/NOrlov/myenv/lib/python3.12/site-packages (from catboost) (2.2.1) Requirement already satisfied: scipy in /Users/NOrlov/myenv/lib/python3.12/site-packages (from catboost) (1.13.0) Requirement already satisfied: plotly in /Users/NOrlov/myenv/lib/python3.12/site-packages (from catboost) (5.20.0) Requirement already satisfied: six in /Users/NOrlov/myenv/lib/python3.12/site-packages (from catboost) (1.16.0) Requirement already satisfied: python-dateutil>=2.8.2 in /Users/NOrlov/myenv/lib/python3.12/site-packages (from pandas>=0.24->catboost) (2.9.0.post0) Requirement already satisfied: pytz>=2020.1 in /Users/NOrlov/myenv/lib/python3.12/site-packages (from pandas>=0.24->catboost) (2024.1) Requirement already satisfied: tzdata>=2022.7 in /Users/NOrlov/myenv/lib/python3.12/site-packages (from pandas>=0.24->catboost) (2024.1) Requirement already satisfied: contourpy>=1.0.1 in /Users/NOrlov/myenv/lib/python3.12/site-packages (from matplotlib->catboost) (1.2.1) Requirement already satisfied: cycler>=0.10 in /Users/NOrlov/myenv/lib/python3.12/site-packages (from matplotlib->catboost) (0.12.1) Requirement already satisfied: fonttools>=4.22.0 in /Users/NOrlov/myenv/lib/python3.12/site-packages (from matplotlib->catboost) (4.50.0) Requirement already satisfied: kiwisolver>=1.3.1 in /Users/NOrlov/myenv/lib/python3.12/site-packages (from matplotlib->catboost) (1.4.5) Requirement already satisfied: packaging>=20.0 in /Users/NOrlov/myenv/lib/python3.12/site-packages (from matplotlib->catboost) (24.0) Requirement already satisfied: pillow>=8 in /Users/NOrlov/myenv/lib/python3.12/site-packages (from matplotlib->catboost) (10.3.0) Requirement already satisfied: pyparsing>=2.3.1 in /Users/NOrlov/myenv/lib/python3.12/site-packages (from matplotlib->catboost) (3.1.2) Requirement already satisfied: tenacity>=6.2.0 in /Users/NOrlov/myenv/lib/python3.12/site-packages (from plotly->catboost) (8.2.3)
# Импорт библиотек
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go
import plotly.express as px
from sklearn.metrics import root_mean_squared_error, mean_squared_error
from sklearn.preprocessing import StandardScaler, OneHotEncoder, LabelEncoder
from sklearn.model_selection import train_test_split, GridSearchCV, cross_val_score, TimeSeriesSplit
from sklearn.linear_model import LinearRegression
from sklearn.dummy import DummyRegressor
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from phik import phik_matrix
from catboost import CatBoostRegressor
import lightgbm as lgb
from statsmodels.tsa.seasonal import seasonal_decompose
from statsmodels.graphics.tsaplots import plot_acf
import time
import os
# Загрузка данных
try:
path = '/datasets/taxi.csv'
data = pd.read_csv(path, index_col=[0], parse_dates=[0])
except FileNotFoundError:
current_directory = os.getcwd()
data = pd.read_csv(current_directory + path, index_col=[0], parse_dates=[0])
display(data.info())
# Назначение индексов
data.sort_index(inplace=True)
# Поиск пропусков
print()
print('Пропуски:', data.isna().sum())
print()
#Ресэмплинг почасово
data = data.resample('1H').sum()
<class 'pandas.core.frame.DataFrame'> DatetimeIndex: 26496 entries, 2018-03-01 00:00:00 to 2018-08-31 23:50:00 Data columns (total 1 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 num_orders 26496 non-null int64 dtypes: int64(1) memory usage: 414.0 KB
None
Пропуски: num_orders 0 dtype: int64
/var/folders/2r/zfmtb66x3rs14ctpbj7lhdt40000gp/T/ipykernel_19445/1415298997.py:15: FutureWarning: 'H' is deprecated and will be removed in a future version, please use 'h' instead.
data = data.resample('1H').sum()
# Вывод датафрейма
display(data.sample(5))
# Визуализация
plt.figure(figsize=(16, 8))
plt.plot(data.index, data['num_orders'], linestyle='-', color='b')
plt.ylabel('Order number per hour')
plt.title('Order numbers per hour');
| num_orders | |
|---|---|
| datetime | |
| 2018-07-07 00:00:00 | 190 |
| 2018-03-17 20:00:00 | 57 |
| 2018-04-04 11:00:00 | 66 |
| 2018-06-05 18:00:00 | 60 |
| 2018-04-30 10:00:00 | 120 |
Выводы¶
Загружен датасет с историческими данными. Произведена первичная обработка и ресемплинг.
Анализ¶
Первичный анализ¶
# Общие сведения о заказах
data.describe()
| num_orders | |
|---|---|
| count | 4416.000000 |
| mean | 84.422781 |
| std | 45.023853 |
| min | 0.000000 |
| 25% | 54.000000 |
| 50% | 78.000000 |
| 75% | 107.000000 |
| max | 462.000000 |
# Гистограмма
plt.figure(figsize = (16,8))
plt.hist(data, bins = 100)
plt.ylabel('Frequency')
plt.xlabel('Order numbers')
plt.title('Order number hist')
plt.grid(True);
# Boxplot
plt.figure(figsize = (16,8))
plt.boxplot(data)
plt.ylabel('Order numbers per hour')
plt.title('Order number hist')
plt.grid(True);
data[data['num_orders']>250].sample(5)
| num_orders | |
|---|---|
| datetime | |
| 2018-08-26 00:00:00 | 281 |
| 2018-08-18 02:00:00 | 254 |
| 2018-08-31 04:00:00 | 268 |
| 2018-08-18 00:00:00 | 281 |
| 2018-04-24 00:00:00 | 253 |
Наблюдается некоторое количество выбросов, в основном они происходят в вечернее и ночное время.
Анализ трендов и сезонности¶
decomposed = seasonal_decompose(data['num_orders'].dropna())
fig = go.Figure()
fig.add_trace(go.Scatter(x=data['num_orders'].dropna().index, y=data['num_orders'].dropna(),
mode='lines',
name='OriginalData'))
fig.add_trace(go.Scatter(x = decomposed.resid.index, y = decomposed.resid,
mode='lines',
name='Resid'))
fig.add_trace(go.Scatter(x = decomposed.seasonal.index, y = decomposed.seasonal,
mode='lines',
name='Season'))
fig.add_trace(go.Scatter(x = decomposed.trend.index, y = decomposed.trend,
mode='lines',
name='Trend'))
# Фурье преобразование сезонной компоненты
T = 1 # Интервал между отсчетами 1 час
L = len(decomposed.seasonal) # Длина датасета
t = np.arange(0, L*T, T) # Временная ось, часы
# Выполнение FFT
N = len(data)
Y = np.fft.fft(decomposed.seasonal)
Y = Y/np.max(Y)
Y = Y[range(N//2)] # Только положительные частоты
# Создание временной оси
frequencies = 1/np.arange(0, (1/(2*T)), (1/(2*T))/len(Y))
fig = go.Figure()
fig.add_trace(go.Scatter(x=frequencies, y=np.abs(Y), mode='lines', name='FFT'))
fig.update_layout(title='Анализ временных интервалов "сезонной" компоненты данных',
xaxis_title='Периоды, часы',
yaxis_title='Относительная амплитуда')
fig.show();
/var/folders/2r/zfmtb66x3rs14ctpbj7lhdt40000gp/T/ipykernel_19445/122445324.py:16: RuntimeWarning: divide by zero encountered in divide
# Построение коррелограммы сезонной компоненты
corr_data_seasonal = decomposed.seasonal.dropna()
fig, ax = plt.subplots(figsize=(16, 8))
# Построение коррелограммы с использованием fig и ax
plot_acf(corr_data_seasonal, lags=24*14, ax=ax) # lags - количество лагов для отображения
plt.xlabel('Лаги, часы')
plt.ylabel('Автокорреляция')
plt.grid(True)
plt.title('Коррелограмма остатков')
Text(0.5, 1.0, 'Коррелограмма остатков')
# Построение коррелограммы остаточной компоненты
corr_data_resid = decomposed.resid.dropna()
fig, ax = plt.subplots(figsize=(16, 8))
# Построение коррелограммы с использованием fig и ax
plot_acf(corr_data_resid, lags=24*14, ax=ax) # lags - количество лагов для отображения
plt.xlabel('Лаги, часы')
plt.ylabel('Автокорреляция')
plt.grid(True)
plt.title('Коррелограмма остаточной компоненты')
print('RMS остаточной компоненты:', np.std(corr_data_resid))
RMS остаточной компоненты: 24.694084222676633
Выводы¶
В ходе временного и частотного анализа удалось выявить следующие закономерности.
Тренд восходящий, это значит, что количество заказов на выбранном интервале времени датасета возрастает со временем и ряд нестационарен.
На Фурье преобразовании сезонной компоненты наиболее заметны следующие интервалы: 12 часов, 24 часа, 8 часов, 6 часов, 4 часа. В целом, они выглядят логичными и их можно попробовать увязать с внутридневными циклами жизни людей. На кореллограмме сезонной компоненты заметны пики с регулярностью 24 часа, что совпадает с результатами Фурье преобразования.
Согласно анализу суточных периодов, локальные максимумы заказов наблюдаются в 2, 10, 16 и 24 часа, локальные минимумы в 6, 12 и 18 часов. Судя по всему, в этот аэропорт по большей части совершается большое количество регулярных рейсов ежедневно.
Компонента остатков содержит на кореллограмме суточные пики, а также еженедельные.
Компонента остатков выглядит достаточно равномерной с выбросами, учащающимися и увеличивающимися со временем, что согласуется с растущим трендом.
Feature engineering¶
# Создадим функцию, которое добавляет столбцы с часами, числом, днем недели, месяцем, а также прошлые данные и скользящее среднее
def make_features(data, max_lag, rolling_mean_size):
data['hour'] = data.index.hour
data['day'] = data.index.day
data['dayofweek'] = data.index.dayofweek
for lag in range(1, max_lag + 1):
data['lag_{}'.format(lag)] = data['num_orders'].shift(lag)
data['rolling_mean'] = data['num_orders'].shift().rolling(rolling_mean_size).mean()
# Создаем новые фичи в датасете с помощью функции
make_features(data, 8, 24)
data_final = data.dropna()
data_final.sample(15)
| num_orders | hour | day | dayofweek | lag_1 | lag_2 | lag_3 | lag_4 | lag_5 | lag_6 | lag_7 | lag_8 | rolling_mean | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| datetime | |||||||||||||
| 2018-06-17 20:00:00 | 71 | 20 | 17 | 6 | 78.0 | 37.0 | 124.0 | 156.0 | 91.0 | 91.0 | 50.0 | 88.0 | 80.291667 |
| 2018-05-13 07:00:00 | 13 | 7 | 13 | 6 | 23.0 | 24.0 | 61.0 | 155.0 | 71.0 | 59.0 | 95.0 | 71.0 | 65.041667 |
| 2018-06-10 04:00:00 | 88 | 4 | 10 | 6 | 106.0 | 101.0 | 83.0 | 194.0 | 114.0 | 110.0 | 88.0 | 119.0 | 90.125000 |
| 2018-07-03 19:00:00 | 95 | 19 | 3 | 1 | 78.0 | 108.0 | 130.0 | 95.0 | 79.0 | 58.0 | 81.0 | 93.0 | 85.958333 |
| 2018-07-10 18:00:00 | 74 | 18 | 10 | 1 | 108.0 | 120.0 | 80.0 | 97.0 | 77.0 | 72.0 | 98.0 | 103.0 | 90.416667 |
| 2018-07-17 17:00:00 | 120 | 17 | 17 | 1 | 57.0 | 97.0 | 74.0 | 49.0 | 66.0 | 121.0 | 79.0 | 108.0 | 86.833333 |
| 2018-04-02 11:00:00 | 82 | 11 | 2 | 0 | 111.0 | 57.0 | 62.0 | 7.0 | 20.0 | 33.0 | 33.0 | 47.0 | 62.791667 |
| 2018-03-09 04:00:00 | 30 | 4 | 9 | 4 | 31.0 | 29.0 | 66.0 | 84.0 | 79.0 | 88.0 | 86.0 | 114.0 | 63.125000 |
| 2018-03-11 23:00:00 | 118 | 23 | 11 | 6 | 89.0 | 119.0 | 99.0 | 99.0 | 107.0 | 78.0 | 93.0 | 60.0 | 64.791667 |
| 2018-04-30 14:00:00 | 49 | 14 | 30 | 0 | 62.0 | 49.0 | 63.0 | 120.0 | 91.0 | 69.0 | 21.0 | 32.0 | 77.000000 |
| 2018-08-13 05:00:00 | 83 | 5 | 13 | 0 | 143.0 | 168.0 | 437.0 | 273.0 | 160.0 | 185.0 | 118.0 | 135.0 | 130.500000 |
| 2018-05-13 22:00:00 | 88 | 22 | 13 | 6 | 101.0 | 94.0 | 70.0 | 48.0 | 117.0 | 128.0 | 54.0 | 71.0 | 68.916667 |
| 2018-06-11 08:00:00 | 58 | 8 | 11 | 0 | 17.0 | 46.0 | 68.0 | 101.0 | 116.0 | 113.0 | 124.0 | 152.0 | 83.125000 |
| 2018-07-29 21:00:00 | 84 | 21 | 29 | 6 | 114.0 | 89.0 | 89.0 | 131.0 | 178.0 | 119.0 | 135.0 | 116.0 | 106.333333 |
| 2018-05-28 02:00:00 | 186 | 2 | 28 | 0 | 78.0 | 157.0 | 113.0 | 111.0 | 77.0 | 78.0 | 84.0 | 37.0 | 72.541667 |
# Рассмотрим распределение данных, сгруппированных по отдельным признакам
# по часам
hour = data_final.groupby(by = 'hour').mean()['num_orders']
plt.figure(figsize = (16,8))
plt.bar(hour.index, hour)
plt.title('Hour distribution')
plt.xlabel('Hour')
plt.ylabel('Mean order value')
plt.show()
# по дням недели
week = data_final.groupby(by = 'dayofweek').mean()['num_orders']
plt.figure(figsize = (16,8))
plt.bar(week.index, week)
plt.title('Day of Week distribution')
plt.xlabel('Day of Week')
plt.ylabel('Mean order value')
plt.show()
# по числу
day = data_final.groupby(by = 'day').mean()['num_orders']
plt.figure(figsize = (16,8))
plt.bar(day.index, day)
plt.title('Day')
plt.xlabel('Day')
plt.ylabel('Mean order value')
plt.show()
# Выбираем интересующие нас числовые признаки
interval_cols = ['num_orders']
# Вычисляем матрицу корреляции с использованием phik
phik_corr_matrix = phik_matrix(data_final[['num_orders', 'dayofweek', 'hour', 'day']], interval_cols=interval_cols)
plt.figure(figsize=(18, 10))
sns.heatmap(phik_corr_matrix, annot=True, cmap='coolwarm', vmin=-1, vmax=1)
plt.title('PhiK Correlation Coefficients Heatmap')
plt.show();
Выводы¶
Корреляционный анализ и столбчатые диаграммы продемонстрировали схожие результаты зависимости целевого признака и новых фичей. Количество заказов коррелирует со временем, растет со временем и слабо зависит от числа и дня недели.
Обучение¶
Создание тестовой и тренировочной выборок¶
# Data split
train, test = train_test_split(data_final, shuffle=False, test_size=0.1, random_state=42)
train = train.dropna()
y_train = train['num_orders']
X_train = train.drop('num_orders', axis = 1)
y_test = test['num_orders']
X_test = test.drop('num_orders', axis = 1)
CatBoostRegressor model¶
# Создаем модель CatBoostRegressor
model_cb = CatBoostRegressor(random_state = 41, verbose = False)
# Задаем сетку параметров для GridSearchCV
param_grid = {
'learning_rate': [0.01, 0.03, 0.1],
'depth': [3, 6, 8],
'l2_leaf_reg': [0.1, 1, 5],
'iterations': [500]
}
# Инициализируем TimeSeriesSplit для кросс-валидации временных рядов
tscv = TimeSeriesSplit(n_splits=3)
grid_search_cb = GridSearchCV(estimator=model_cb, param_grid=param_grid, scoring='neg_mean_squared_error', cv=tscv, verbose=1, n_jobs=-1, error_score='raise')
grid_search_cb.fit(X_train, y_train)
best_params_cb = grid_search_cb.best_params_
best_model_cb = grid_search_cb.best_estimator_
Fitting 3 folds for each of 27 candidates, totalling 81 fits
# Результаты обучения и получение RMSE метрик на тестовом и тренировочном наборах
print("Best Hyperparameters CatBoostRegressor:", best_params_cb)
print("Best CatBoostRegressor model CV RMSE:", best_model_cb.best_score_)
y_pred_train_cb = best_model_cb.predict(X_train)
rmse_train_cb = mean_squared_error(y_train, y_pred_train_cb, squared = False)
print("Best CatBoostRegressor model Train RMSE:", rmse_train_cb)
print("Best CatBoostRegressor model Train mean residual:", np.mean(y_train - y_pred_train_cb))
Best Hyperparameters CatBoostRegressor: {'depth': 3, 'iterations': 500, 'l2_leaf_reg': 1, 'learning_rate': 0.1}
Best CatBoostRegressor model CV RMSE: {'learn': {'RMSE': 18.228244040730306}}
Best CatBoostRegressor model Train RMSE: 18.228244040730313
Best CatBoostRegressor model Train mean residual: 0.0023369308604390755
/Users/NOrlov/myenv/lib/python3.12/site-packages/sklearn/metrics/_regression.py:483: FutureWarning: 'squared' is deprecated in version 1.4 and will be removed in 1.6. To calculate the root mean squared error, use the function'root_mean_squared_error'.
# Замер времени обучения лучшей модели
start_time = time.time()
time_model_cb = best_model_cb
time_model_cb.fit(X_train, y_train)
training_time_cb = time.time() - start_time
# Замер времени предсказания лучшей модели
start_time = time.time()
y_pred_cb = time_model_cb.predict(X_train)
prediction_time_cb = time.time() - start_time
print('CatBoostRegressor model training time, s:', training_time_cb)
print('CatBoostRegressor model prediction time, s:', prediction_time_cb)
CatBoostRegressor model training time, s: 0.29205894470214844 CatBoostRegressor model prediction time, s: 0.0012660026550292969
# Визуализация y_true, y_pred, resid и их разницы для наглядности
fig = go.Figure()
fig.add_trace(go.Scatter(x=X_train.index, y=y_train,
mode='lines',
name='Original data'))
fig.add_trace(go.Scatter(x = X_train.index, y = y_pred_train_cb,
mode='lines',
name='CB prediction'))
fig.add_trace(go.Scatter(x = X_train.index, y = decomposed.resid.loc[X_train.index],
mode='lines',
name='Resid'))
fig.add_trace(go.Scatter(x = X_train.index, y = decomposed.resid.loc[X_train.index] + y_pred_train_cb.mean() - y_pred_train_cb,
mode='lines',
name='Resid-error'))
fig.update_layout(xaxis_tickformat='%Y-%m-%d', xaxis_title='Datetime')
fig.update_layout(width=1000, height=800)
fig.show()
# Построение гистограммы остатков на Train наборе
plt.figure(figsize = (16,8))
plt.hist(y_train - y_pred_train_cb, bins=100, label = 'model residual')
plt.hist(decomposed.resid.loc[X_train.index], bins=100, alpha = 0.7, label = 'residual component')
plt.xlabel('y_true - y_pred_train')
plt.ylabel('Frequency')
plt.axvline(x=0, color='r', linestyle='-', linewidth=1, label = 'zero')
plt.axvline(x=np.mean(y_train - y_pred_train_cb), color='b', linestyle='-', linewidth=2, alpha = 0.2, label = 'residual mean')
plt.legend()
plt.title('CatBoostRegressor residual hist')
plt.grid('true')
plt.show();
# Поиск наиболее и наименее значимых фичей
all_feature_names = X_train.columns
feature_importances = best_model_cb.feature_importances_
feature_importance_df = pd.DataFrame({
'Feature': all_feature_names,
'Importance': feature_importances
})
# Максимальные признаки
feature_importance_df_max = feature_importance_df.sort_values(
by='Importance', ascending=False
)
plt.figure(figsize=(12, 8))
plt.barh(feature_importance_df_max['Feature'], feature_importance_df_max['Importance'], color='skyblue')
plt.xlabel('Importance')
plt.ylabel('Features')
plt.title('Maximal Feature Importance CatBoostRegressor')
plt.gca().invert_yaxis()
plt.show()
CatBoostRegressor Выводы¶
Модель с лучшими гиперпараметрами продемонстрировала на кросс-валидации метрику RMSE лучше порогового значения.
Анализ остатков показал нормальное распределение, остатки на тренировочной выборке имеют нулевое среднее. Остаточная компонента и остатки модели имеют большие коррелирующие между собой единичные выбросы.
Анализ важности фичей подтвердил адекватность модели: время суток, информация предыдщего часа и скользящее среднее больше всего влияют на предсказание. Данные с большой задержкой и календарное число оказывают наименьшее влияние на результат предсказания.
LightGBM model¶
# Задаем сетку параметров для GridSearchCV
params_lgb = {
'boosting_type': 'gbdt',
'objective': 'regression',
'metric': 'rmse',
'max_depth': -1,
'min_child_samples': 20,
'subsample': 0.8,
'colsample_bytree': 0.8,
'reg_alpha': 0.0,
'reg_lambda': 0.0,
'verbose': 0,
'force_col_wise': True,
'n_estimators' : 500
}
param_grid_lgb = {
'num_leaves': [10, 20, 30],
'learning_rate': [0.01, 0.03, 0.1]
}
# Создаем модель LightGBM
model_lgb = lgb.LGBMRegressor(**params_lgb, random_state=42)
grid_search_lgb = GridSearchCV(estimator=model_lgb, param_grid=param_grid_lgb, scoring='neg_mean_squared_error', cv=tscv, verbose=1, n_jobs=-1, error_score='raise')
grid_search_lgb.fit(X_train, y_train)
best_params_lgb = grid_search_lgb.best_params_
best_model_lgb = grid_search_lgb.best_estimator_
Fitting 3 folds for each of 9 candidates, totalling 27 fits
# Результаты обучения и получение RMSE метрик на тренировочном наборе
print("Best Hyperparameters LightGBM:", best_params_lgb)
print("Best LightGBM model CV RMSE:", best_model_cb.best_score_)
y_pred_train_lgb = best_model_lgb.predict(X_train)
rmse_train_lgb = mean_squared_error(y_train, y_pred_train_lgb, squared = False)
print("Best LightGBM model Train RMSE:", rmse_train_lgb)
print("Best LightGBM Train mean residual:", np.mean(y_train - y_pred_train_lgb))
Best Hyperparameters LightGBM: {'learning_rate': 0.01, 'num_leaves': 30}
Best LightGBM model CV RMSE: {'learn': {'RMSE': 18.228244040730306}}
Best LightGBM model Train RMSE: 17.493857826582488
Best LightGBM Train mean residual: 1.5362727913915024e-10
/Users/NOrlov/myenv/lib/python3.12/site-packages/sklearn/metrics/_regression.py:483: FutureWarning: 'squared' is deprecated in version 1.4 and will be removed in 1.6. To calculate the root mean squared error, use the function'root_mean_squared_error'.
# Замер времени обучения лучшей модели
start_time = time.time()
time_model_lgb = best_model_lgb
time_model_lgb.fit(X_train, y_train)
training_time_lgb = time.time() - start_time
# Замер времени предсказания лучшей модели
start_time = time.time()
y_pred_lgb = time_model_lgb.predict(X_train)
prediction_time_lgb = time.time() - start_time
print('LightGBM model training time, s:', training_time_lgb)
print('LightGBM model prediction time, s:', prediction_time_lgb)
LightGBM model training time, s: 1.1269159317016602 LightGBM model prediction time, s: 0.027966022491455078
# Визуализация y_true, y_pred, resid и их разницы для наглядности
fig = go.Figure()
fig.add_trace(go.Scatter(x=X_train.index, y=y_train,
mode='lines',
name='Original data'))
fig.add_trace(go.Scatter(x = X_train.index, y = y_pred_train_lgb,
mode='lines',
name='CB prediction'))
fig.add_trace(go.Scatter(x = X_train.index, y = decomposed.resid.loc[X_train.index],
mode='lines',
name='Resid'))
fig.add_trace(go.Scatter(x = X_train.index, y = decomposed.resid.loc[X_train.index] + y_pred_train_lgb.mean() - y_pred_train_lgb,
mode='lines',
name='Resid-error'))
fig.update_layout(xaxis_tickformat='%Y-%m-%d', xaxis_title='Datetime')
fig.update_layout(width=1000, height=800)
fig.show()
# Построение гистограммы остатков
plt.figure(figsize = (16,8))
plt.hist(y_train - y_pred_train_lgb, bins=100, label = 'model residual')
plt.hist(decomposed.resid.loc[X_train.index], bins=100, alpha = 0.7, label = 'residual component')
plt.xlabel('y_true - y_pred_train')
plt.ylabel('Frequency')
plt.axvline(x=0, color='r', linestyle='-', linewidth=1, label = 'zero')
plt.axvline(x=np.mean(y_train - y_pred_train_lgb), color='b', linestyle='-', linewidth=2, alpha = 0.2, label = 'residual mean')
plt.legend()
plt.title('LightGBM residual hist')
plt.grid('true')
plt.show();
# Поиск наиболее и наименее значимых фичей
feature_importances = best_model_lgb.feature_importances_
feature_importance_df = pd.DataFrame({
'Feature': all_feature_names,
'Importance': feature_importances
})
# Максимальные признаки
feature_importance_df_max = feature_importance_df.sort_values(
by='Importance', ascending=False
)
plt.figure(figsize=(12, 8))
plt.barh(feature_importance_df_max['Feature'], feature_importance_df_max['Importance'], color='skyblue')
plt.xlabel('Importance')
plt.ylabel('Features')
plt.title('Maximal Feature Importance CatBoostRegressor')
plt.gca().invert_yaxis()
plt.show()
LightGBM Выводы¶
Модель LightGBM с лучшими гиперпараметрами продемонстрировала кросс-валидационную метрику RMSE лучше порогового.
Анализ остатков показал нормальное распределение, остатки на тренировочной выборке имеют нулевое среднее. Остаточная компонента и остатки модели имеют большие коррелирующие между собой единичные выбросы.
Анализ важности фичей подтвердил адекватность модели: время суток и информация предыдущего часа и скользящее среднее больше всего влияют на предсказание. С другой стороны, число и данные с большой задержкой оказывают наименьшее влияние на результат предсказания, что выглядит логично.
Выводы¶
Выбираем модель CatBoostRegressor в качестве финальной. Метрики RMSE двух моделей отличаются крайне незначительно, однако время обучения и время предсказания CatBoostRegressor на 2 порядка меньше, чем соответствующие времена LightGBM.
Анализ модели¶
# Результаты обучения и получение RMSE метрик на тестовом и тренировочном наборах
y_pred_test_cb = best_model_cb.predict(X_test)
rmse_test_cb = mean_squared_error(y_test, y_pred_test_cb, squared = False)
print("Best CatBoostRegressor model Test RMSE:", rmse_test_cb)
print("Best CatBoostRegressor model Test mean residual:", np.mean(y_test - y_pred_test_cb))
Best CatBoostRegressor model Test RMSE: 44.039583298055476 Best CatBoostRegressor model Test mean residual: 14.394662367664692
/Users/NOrlov/myenv/lib/python3.12/site-packages/sklearn/metrics/_regression.py:483: FutureWarning: 'squared' is deprecated in version 1.4 and will be removed in 1.6. To calculate the root mean squared error, use the function'root_mean_squared_error'.
# Визуализация y_true, y_pred, resid и их разницы для наглядности
fig = go.Figure()
fig.add_trace(go.Scatter(x=X_test.index, y=y_test,
mode='lines',
name='Original data'))
fig.add_trace(go.Scatter(x = X_test.index, y = y_pred_test_cb,
mode='lines',
name='CB prediction'))
fig.add_trace(go.Scatter(x = X_test.index, y = decomposed.resid.loc[X_test.index],
mode='lines',
name='Resid'))
fig.add_trace(go.Scatter(x = X_test.index, y = decomposed.resid.loc[X_test.index] + y_pred_test_cb.mean() - y_pred_test_cb,
mode='lines',
name='Resid-error'))
fig.update_layout(xaxis_tickformat='%Y-%m-%d', xaxis_title='Datetime')
fig.update_layout(width=1000, height=800)
fig.show()
# Построение гистограммы остатков на Test наборе
plt.figure(figsize = (16,8))
plt.hist(y_test - y_pred_test_cb, bins=100, label = 'model residual')
plt.hist(decomposed.resid.loc[X_test.index], bins=100, alpha = 0.7, label = 'residual component')
plt.xlabel('y_true - y_pred_test')
plt.ylabel('Frequency')
plt.axvline(x=0, color='r', linestyle='-', linewidth=1, label = 'zero')
plt.axvline(x=np.mean(y_test - y_pred_test_cb), color='b', linestyle='-', linewidth=2, alpha = 0.2, label = 'residual mean')
plt.legend()
plt.title('CatBoostRegressor residual hist')
plt.grid('true')
plt.show();
Выводы¶
Модель CatBoostRegressor с лучшими гиперпараметрами продемонстрировала на тестовой выборке метрику RMSE лучше порогового значения.
Анализ остатков показал нормальное распределение, однако остатки на тестовой выборке имеют ненулевое среднее, что говорит о недообученности модели. С другой стороны, остаточная компонента и остатки модели имеют большие коррелирующие между собой единичные выбросы, которые могут объяснять это смещение, особенно при небольшой тестовой выборке.
Выводы¶
Были изучены исторические данные сервиса заказа такси в аэропортах. Произведены предподготовка, трендовый, статистический и корреляционный анализы данных. Также был проведен feature engineering.
С целью прогнозирования объема спроса на такси в аэропорту были построены 2 модели с оптимизацией гиперпараметров, рассматривались CatBoosstRegressor и LightGBM.
LightGBM и CatBoostRegression продемонстрировали допустимые и схожие RMSE метрики - обе менее требуемых 48. Обе модели выглядят адекватными, однако CatBoostRegression демонстрирует быстродействие на 2 порядка лучше, поэтому рещено выбрать эту модель в качестве рабочей.